Title
Using a K-NN Classification Model to Predict the Genre of a Given Song Based on Danceability and Energy
Introduction
Today, listening to music is more accessible than ever. Popular streaming platforms like Spotify make it easy for users to discover new genres and receive recommendations aligned with their preferences (Ignatius Moses Setiadi et al., 2020). Recommendations help users find songs tailored to their tastes, and producing them often involves classifying songs by genre with a variety of classifiers (Ignatius Moses Setiadi et al., 2020). Enjoyment of a song can depend on factors such as emotional impact, catchy melodies, or impactful lyrics (Khan et al., 2022). Audio features like loudness, tempo, or energy can also be used to classify a song’s genre, and music streaming platforms often use them to recommend new songs to their users (Khan et al., 2022).
Based on this information, the question we want to answer with our project is: “What is the genre of a given song based on its danceability and energy values?” This is a classification question: it uses one or more variables to predict the value of a categorical variable of interest. We will use the K-nearest neighbors (KNN) algorithm to predict the genre of a given song. KNN predicts the class of a test observation by calculating the Euclidean distance between that observation and every training point (Taunk et al., 2019). The observation is then assigned to the most common class among its K nearest neighbors, where K is the number of neighbors considered (Taunk et al., 2019). The best value of K depends on the dataset and is not always the largest one, because a large K can pull unrelated points into the neighborhood and blur the classification boundaries (Taunk et al., 2019).
The dataset we will use is “Dataset of songs in Spotify” from Kaggle. It has 22 columns, including danceability, energy, key, loudness, mode, speechiness, acousticness, instrumentalness, liveness, valence, tempo, type, id, uri, track_href, analysis_url, duration_ms, time_signature, genre, song_name, and title. The full list of genres includes Trap, Techno, Techhouse, Trance, Psytrance, Dark Trap, DnB (drum and bass), Hardstyle, Underground Rap, Trap Metal, Emo, Rap, RnB, Pop and Hiphop. We will use danceability (from 0 to 0.99), energy (from 0 to 1), and three genres: Emo, Hardstyle (stored as hardstyle in the data), and Hiphop.
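To make the distance-and-vote idea concrete, here is a minimal base-R sketch of KNN on two predictors. It is only an illustration: the six training rows are made-up values rather than rows from the Spotify dataset, the helper name `knn_predict` is hypothetical, and the predictors are left unscaled for simplicity.

```r
# Toy training data: hypothetical values, NOT taken from the Spotify dataset.
train <- data.frame(
  energy       = c(0.90, 0.85, 0.60, 0.65, 0.75, 0.70),
  danceability = c(0.45, 0.50, 0.75, 0.70, 0.48, 0.52),
  genre        = c("hardstyle", "hardstyle", "Hiphop", "Hiphop", "Emo", "Emo")
)

# Hypothetical helper: classify one new point by majority vote
# among its k nearest training points.
knn_predict <- function(train, new_energy, new_danceability, k) {
  # Euclidean distance from the new point to every training point
  distances <- sqrt((train$energy - new_energy)^2 +
                    (train$danceability - new_danceability)^2)
  # Labels of the k closest training points
  nearest <- train$genre[order(distances)[seq_len(k)]]
  # Most frequent label among them (on a tie, the first label in table order wins)
  votes <- table(nearest)
  names(votes)[which.max(votes)]
}

knn_predict(train, new_energy = 0.88, new_danceability = 0.47, k = 3)
# → "hardstyle" (the three nearest points are two hardstyle songs and one Emo song)
```

With K = 3 the new point is outvoted 2 to 1 in favour of hardstyle. Raising K to 6 would include every training point, which shows how too large a K can let distant, unrelated points influence the vote.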
Preliminary exploratory data analysis
library(readr)
library(repr)
library(tidyverse)
library(tidymodels)
library(ggplot2)
options(repr.matrix.max.rows = 10)
# Load the dataset directly from the project's GitHub repository
urlfile <- "https://raw.githubusercontent.com/brandonzchen/GroupProjDSCI/main/genres_v2.csv"
mydata <- read_csv(urlfile)
Rows: 42305 Columns: 22 ── Column specification ──────────────────────────────────────────────────────── Delimiter: "," chr (8): type, id, uri, track_href, analysis_url, genre, song_name, title dbl (14): danceability, energy, key, loudness, mode, speechiness, acousticne... ℹ Use `spec()` to retrieve the full column specification for this data. ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Summarise the three genres of interest: observation count and mean predictor values
datainformation <- mydata |>
  select(danceability, energy, genre) |>
  filter(genre %in% c("Emo", "hardstyle", "Hiphop")) |>
  group_by(genre) |>
  summarise(
    count = n(),
    mean_energy = mean(energy),
    mean_danceability = mean(danceability)
  )
datainformation
| genre | count | mean_energy | mean_danceability |
|---|---|---|---|
| <chr> | <int> | <dbl> | <dbl> |
| Emo | 1680 | 0.7611750 | 0.4936988 |
| Hiphop | 3028 | 0.6544179 | 0.6989818 |
| hardstyle | 2936 | 0.8962384 | 0.4780270 |
# Keep only the two predictors and the label for the three genres, with genre as a factor
song_data <- mydata |>
  select(danceability, energy, genre) |>
  filter(genre %in% c("Emo", "hardstyle", "Hiphop")) |>
  mutate(genre = as_factor(genre)) |>
  drop_na()
# Scatterplot of the two predictors, coloured by genre
genre_plot <- song_data |>
  ggplot(aes(x = energy, y = danceability)) +
  geom_point(alpha = 0.4, aes(colour = genre)) +
  ggtitle("Figure 1: Scatterplot of the Genres based on Energy and Danceability") +
  xlab("Energy") +
  ylab("Danceability") +
  labs(colour = "Genre") +
  theme(text = element_text(size = 18))
options(repr.plot.width = 10, repr.plot.height = 8)
genre_plot
set.seed(2023)
song_split <- initial_split(song_data, prop = 0.75, strata = genre)
song_train <- training(song_split)
song_test <- testing(song_split)
# Standardise both predictors so neither dominates the distance calculation
knn_recipe <- recipe(genre ~ energy + danceability, data = song_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())
# KNN specification with the number of neighbors left to be tuned
knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")
# 5-fold cross-validation over K = 75, 80, ..., 100
knn_vfold <- vfold_cv(song_train, v = 5, strata = genre)
k_vals <- tibble(neighbors = seq(from = 75, to = 100, by = 5))
knn_results <- workflow() |>
  add_recipe(knn_recipe) |>
  add_model(knn_spec) |>
  tune_grid(resamples = knn_vfold, grid = k_vals) |>
  collect_metrics()
# Cross-validated accuracy for each candidate K
accuracies <- knn_results |>
  filter(.metric == "accuracy")
k_vs_accuracy_plot <- accuracies |>
  ggplot(aes(x = neighbors, y = mean)) +
  geom_point() +
  geom_line() +
  labs(x = "Neighbors", y = "Estimated Accuracy") +
  ggtitle("Figure 2: Plot of Number of Neighbors vs Estimated Accuracy") +
  theme(text = element_text(size = 15)) +
  scale_x_continuous(breaks = seq(75, 100, by = 5))
options(repr.plot.width = 10, repr.plot.height = 8)
k_vs_accuracy_plot
set.seed(2023)
# Refit on the full training set with the chosen K = 80
song_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 80) |>
  set_engine("kknn") |>
  set_mode("classification")
song_fit <- workflow() |>
  add_recipe(knn_recipe) |>
  add_model(song_spec) |>
  fit(data = song_train)
# Accuracy of the fitted model on the held-out test set
song_test_predictions <- predict(song_fit, song_test) |>
  bind_cols(song_test) |>
  metrics(truth = genre, estimate = .pred_class) |>
  filter(.metric == "accuracy")
song_test_predictions
| .metric | .estimator | .estimate |
|---|---|---|
| <chr> | <chr> | <dbl> |
| accuracy | multiclass | 0.7247514 |
# Retrain on all of the data before predicting genres for new songs
song_recipe <- recipe(genre ~ energy + danceability, data = song_data) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())
song_fit_real <- workflow() |>
  add_recipe(song_recipe) |>
  add_model(song_spec) |>
  fit(data = song_data)
new_song_1 <- tibble(energy = 0.29, danceability = 0.56)
new_song_2 <- tibble(energy = 0.889, danceability = 0.628)
new_song_3 <- tibble(energy = 0.84, danceability = 0.75)
new_song_1_predicted <- predict(song_fit_real, new_song_1)
new_song_2_predicted <- predict(song_fit_real, new_song_2)
new_song_3_predicted <- predict(song_fit_real, new_song_3)
new_song_1_predicted
new_song_2_predicted
new_song_3_predicted
| .pred_class |
|---|
| <fct> |
| Emo |
| .pred_class |
|---|
| <fct> |
| hardstyle |
| .pred_class |
|---|
| <fct> |
| Hiphop |
new_songs_predicted_plot <- song_data |>
  ggplot(aes(x = energy, y = danceability)) +
  geom_point(alpha = 0.4, aes(colour = genre)) +
  xlab("Energy") +
  ylab("Danceability") +
  labs(colour = "Genre") +
  theme(text = element_text(size = 12)) +
  # annotate() draws each new song once, rather than once per row of song_data
  annotate("point", x = 0.29, y = 0.56, colour = "black", size = 4) +
  annotate("point", x = 0.889, y = 0.628, colour = "purple", size = 4) +
  annotate("point", x = 0.84, y = 0.75, colour = "brown", size = 4) +
  ggtitle("Figure 3: Scatterplot of Genres based on Energy and Danceability with New Song Predictions")
options(repr.plot.width = 10, repr.plot.height = 8)
new_songs_predicted_plot